r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import random
import os
from collections import namedtuple
from torch._six import queue
from torch._utils import ExceptionWrapper
from torch.utils.data._utils import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS

from .my_random_resize_crop import MyRandomResizedCrop

__all__ = ['worker_loop']

if IS_WINDOWS:
	import ctypes
	from ctypes.wintypes import DWORD, BOOL, HANDLE


	# On Windows, the parent ID of the worker process remains unchanged when the manager process
	# is gone, and the only way to check it through OS is to let the worker have a process handle
	# of the manager and ask if the process status has changed.
	class ManagerWatchdog(object):
		r"""Windows watchdog that reports whether the parent (manager) process is alive.

		A process handle to the parent is opened once at construction time and
		polled, without blocking, on every :meth:`is_alive` call. Once the
		parent has been seen dead the result is latched and no further system
		calls are made.
		"""

		def __init__(self):
			self.manager_pid = os.getppid()

			# Declare the prototypes of the two kernel32 functions used here so
			# that ctypes converts arguments and return values correctly.
			kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
			kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
			kernel32.OpenProcess.restype = HANDLE
			kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
			kernel32.WaitForSingleObject.restype = DWORD
			self.kernel32 = kernel32

			# Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
			SYNCHRONIZE = 0x00100000
			handle = kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
			self.manager_handle = handle

			# A null/zero handle means OpenProcess failed; surface the Win32 error.
			if not handle:
				raise ctypes.WinError(ctypes.get_last_error())

			self.manager_dead = False

		def is_alive(self):
			r"""Return ``True`` while the manager process has not been observed dead."""
			if self.manager_dead:
				# Latched: never poll again once the manager is known to be gone.
				return False
			# Zero timeout makes this a non-blocking poll. The result is compared
			# against 0; value obtained from
			# https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
			wait_result = self.kernel32.WaitForSingleObject(self.manager_handle, 0)
			self.manager_dead = wait_result == 0
			return not self.manager_dead
else:
	class ManagerWatchdog(object):
		r"""Watchdog that reports whether the parent (manager) process is alive.

		The parent pid is recorded at construction time; the manager is treated
		as dead as soon as ``os.getppid()`` returns a different value. Once that
		has been observed the result is latched.
		"""

		def __init__(self):
			self.manager_dead = False
			self.manager_pid = os.getppid()

		def is_alive(self):
			r"""Return ``True`` while the recorded parent pid is still our parent."""
			if self.manager_dead:
				# Latched: skip the pid comparison once the manager is known dead.
				return False
			self.manager_dead = self.manager_pid != os.getppid()
			return not self.manager_dead

# Per-process slot for the current worker's ``WorkerInfo``. It is ``None`` at
# import time, is assigned inside ``worker_loop`` via ``global``, and is what
# ``get_worker_info`` returns.
_worker_info = None


class WorkerInfo(object):
	r"""Attribute bag that becomes read-only once constructed.

	Every keyword argument passed to the constructor is stored as an attribute
	of the same name. After ``__init__`` finishes, any further attribute
	assignment raises :class:`RuntimeError`.
	"""

	# Class-level default; it lets ``__setattr__`` work during ``__init__``
	# before the per-instance flag has been written.
	__initialized = False

	def __init__(self, **kwargs):
		for name in kwargs:
			setattr(self, name, kwargs[name])
		# This assignment itself passes through ``__setattr__`` while the flag
		# still reads ``False``; afterwards the instance is locked.
		self.__initialized = True

	def __setattr__(self, key, val):
		if not self.__initialized:
			return super(WorkerInfo, self).__setattr__(key, val)
		cls_name = self.__class__.__name__
		raise RuntimeError("Cannot assign attributes to {} objects".format(cls_name))


def get_worker_info():
	r"""Return the worker information recorded for the current process.

	Inside a :class:`~torch.utils.data.DataLoader` worker (i.e. after
	``worker_loop`` has started in this process) the returned object exposes:

	* :attr:`id`: the id of this worker.
	* :attr:`num_workers`: the total number of workers.
	* :attr:`seed`: the random seed this worker was started with. See
	  :class:`~torch.utils.data.DataLoader`'s documentation for how it is
	  derived.
	* :attr:`dataset`: the dataset object held by **this** process, which is a
	  different object from the one living in the main process.

	In a process where no worker loop has run (e.g. the main process) the
	result is ``None``.

	.. note::
	   This is handy inside a :attr:`worker_init_fn` passed to
	   :class:`~torch.utils.data.DataLoader` to configure each worker
	   differently, e.g. using the worker id to make the ``dataset`` copy read
	   only one shard, or using ``seed`` to seed other libraries used by the
	   dataset code (e.g., NumPy).
	"""
	info = _worker_info
	return info


# Dummy class used to signal the end of an IterableDataset. Its single field,
# ``worker_id``, identifies the worker whose iterator is exhausted.
_IterableDatasetStopIteration = namedtuple(
	'_IterableDatasetStopIteration', ('worker_id',))


def worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                num_workers):
	r"""Run the fetch loop of a single DataLoader worker.

	The function installs the worker signal handlers, limits torch to one
	thread, seeds ``random`` and ``torch`` with ``seed``, publishes a
	``WorkerInfo`` in the module-level ``_worker_info``, builds a fetcher and
	then serves tasks from ``index_queue`` until it receives ``None`` or the
	watchdog reports that the parent process is gone.

	Args:
		dataset_kind: forwarded to ``_DatasetKind.create_fetcher`` and compared
			against ``_DatasetKind.Iterable`` when handling ``StopIteration``.
		dataset: the dataset to fetch from; also stored in the ``WorkerInfo``.
		index_queue: source of tasks. Each item is either ``None`` (final
			signal) or an ``(idx, index)`` pair.
		data_queue: sink for results; receives ``(idx, data)`` pairs.
		done_event: event polled with ``is_set()``; once set, incoming tasks
			are skipped until ``None`` arrives.
		auto_collation, collate_fn, drop_last: forwarded unchanged to
			``_DatasetKind.create_fetcher``.
		seed: value used to seed ``random`` and ``torch`` in this process.
		init_fn: optional callable invoked as ``init_fn(worker_id)`` before the
			fetcher is created; ``None`` disables it.
		worker_id: id of this worker; used in error messages and end-of-
			iteration markers.
		num_workers: total number of workers; only recorded in ``WorkerInfo``.

	Exceptions raised by ``init_fn``, by fetcher creation or by
	``fetcher.fetch`` are not propagated: they are wrapped in
	``ExceptionWrapper`` and put on ``data_queue``. ``KeyboardInterrupt`` is
	swallowed.
	"""
	# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
	# logic of this function.

	try:
		# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
		# module's handlers are executed after Python returns from C low-level
		# handlers, likely when the same fatal signal had already happened
		# again.
		# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
		signal_handling._set_worker_signal_handlers()

		torch.set_num_threads(1)
		random.seed(seed)
		torch.manual_seed(seed)

		# Publish this worker's metadata so that get_worker_info() returns it
		# from anywhere in this process.
		global _worker_info
		_worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
		                          seed=seed, dataset=dataset)

		from torch.utils.data import _DatasetKind

		# A failure during initialization is kept here and reported as the
		# result of the first task instead of being raised.
		init_exception = None

		try:
			if init_fn is not None:
				init_fn(worker_id)

			fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
		except Exception:
			init_exception = ExceptionWrapper(
				where="in DataLoader worker process {}".format(worker_id))

		# When using Iterable mode, some worker can exit earlier than others due
		# to the IterableDataset behaving differently for different workers.
		# When such things happen, an `_IterableDatasetStopIteration` object is
		# sent over to the main process with the ID of this worker, so that the
		# main process won't send more tasks to this worker, and will send
		# `None` to this worker to properly exit it.
		#
		# Note that we cannot set `done_event` from a worker as it is shared
		# among all processes. Instead, we set the `iteration_end` flag to
		# signify that the iterator is exhausted. When either `done_event` or
		# `iteration_end` is set, we skip all processing step and just wait for
		# `None`.
		iteration_end = False

		watchdog = ManagerWatchdog()

		# Keep serving tasks for as long as the parent process is alive.
		while watchdog.is_alive():
			try:
				# Bounded wait so that the watchdog is re-checked periodically.
				r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
			except queue.Empty:
				continue
			if r is None:
				# Received the final signal
				assert done_event.is_set() or iteration_end
				break
			elif done_event.is_set() or iteration_end:
				# `done_event` or `iteration_end` is set. But I haven't received
				# the final signal (None) yet. I will keep continuing until get
				# it, and skip the processing steps.
				continue
			idx, index = r
			# Local addition to the upstream loop: call the project hook with the
			# task's `idx` before fetching. `sample_image_size` lives in
			# `.my_random_resize_crop`, which is not visible here; presumably it
			# selects the crop size used for this batch - TODO confirm.
			# NOTE(review): this call sits outside the try/except around
			# `fetcher.fetch`, so an exception raised here is not wrapped in an
			# ExceptionWrapper and propagates out of `worker_loop`. The bare
			# strings around it are no-op expression statements used as markers.
			""" Added """
			MyRandomResizedCrop.sample_image_size(idx)
			""" Added """
			if init_exception is not None:
				# Report the initialization failure once, as this task's result.
				data = init_exception
				init_exception = None
			else:
				try:
					data = fetcher.fetch(index)
				except Exception as e:
					if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
						data = _IterableDatasetStopIteration(worker_id)
						# Set `iteration_end`
						#   (1) to save future `next(...)` calls, and
						#   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
						iteration_end = True
					else:
						# It is important that we don't store exc_info in a variable.
						# `ExceptionWrapper` does the correct thing.
						# See NOTE [ Python Traceback Reference Cycle Problem ]
						data = ExceptionWrapper(
							where="in DataLoader worker process {}".format(worker_id))
			data_queue.put((idx, data))
			del data, idx, index, r  # save memory
	except KeyboardInterrupt:
		# Main process will raise KeyboardInterrupt anyways.
		pass
	# Only when shutdown was requested: stop waiting on the queue's background
	# thread and close the queue (assumes `data_queue` is a
	# multiprocessing-style queue - confirm against the caller).
	if done_event.is_set():
		data_queue.cancel_join_thread()
		data_queue.close()
